/*
 * Copyright (c) 2023-2023 elsfs Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.elsfs.cloud.spring.authorizationserver.oidc;

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.OidcScopes;
import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationContext;

/**
 * Maps a userinfo-endpoint request to the {@link OidcUserInfo} returned by that endpoint; the
 * resulting object is written back to the client in {@link UserInfoResponseHandler}.
 *
 * <p>Only the ID-token claims associated with the scopes granted to the access token are released
 * ({@code sub} is always included). In addition, the granted scopes are returned under {@code
 * scope}, and the {@code sub} value is duplicated under a {@code username} entry.
 *
 * @author zeng
 */
public class OidcUserInfoMapper
    implements Function<OidcUserInfoAuthenticationContext, OidcUserInfo> {
  /** Claims released when the {@code email} scope is granted. */
  private static final List<String> EMAIL_CLAIMS =
      Arrays.asList(StandardClaimNames.EMAIL, StandardClaimNames.EMAIL_VERIFIED);

  /** Claims released when the {@code phone} scope is granted. */
  private static final List<String> PHONE_CLAIMS =
      Arrays.asList(StandardClaimNames.PHONE_NUMBER, StandardClaimNames.PHONE_NUMBER_VERIFIED);

  /** Claims released when the {@code profile} scope is granted. */
  private static final List<String> PROFILE_CLAIMS =
      Arrays.asList(
          StandardClaimNames.NAME,
          StandardClaimNames.FAMILY_NAME,
          StandardClaimNames.GIVEN_NAME,
          StandardClaimNames.MIDDLE_NAME,
          StandardClaimNames.NICKNAME,
          StandardClaimNames.PREFERRED_USERNAME,
          StandardClaimNames.PROFILE,
          StandardClaimNames.PICTURE,
          StandardClaimNames.WEBSITE,
          StandardClaimNames.GENDER,
          StandardClaimNames.BIRTHDATE,
          StandardClaimNames.ZONEINFO,
          StandardClaimNames.LOCALE,
          StandardClaimNames.UPDATED_AT);

  /**
   * Builds the userinfo response for the given authentication context.
   *
   * @param authenticationContext context exposing the authorization (which holds the ID token) and
   *     the access token presented at the userinfo endpoint
   * @return the claims to return to the client: the scope-filtered ID-token claims plus {@code
   *     scope} and {@code username}
   */
  @Override
  public OidcUserInfo apply(OidcUserInfoAuthenticationContext authenticationContext) {
    OAuth2Authorization authorization = authenticationContext.getAuthorization();
    // NOTE(review): assumes an ID token is attached to this authorization; if getToken(...)
    // returned null this line would throw a NullPointerException.
    OidcIdToken idToken = authorization.getToken(OidcIdToken.class).getToken();
    OAuth2AccessToken accessToken = authenticationContext.getAccessToken();
    Set<String> scopes = accessToken.getScopes();

    Map<String, Object> scopeRequestedClaims =
        getClaimsRequestedByScope(idToken.getClaims(), scopes);
    scopeRequestedClaims.put(OAuth2ParameterNames.SCOPE, scopes);
    scopeRequestedClaims.put("username", idToken.getClaim(StandardClaimNames.SUB));
    return new OidcUserInfo(scopeRequestedClaims);
  }

  /**
   * Returns a new, mutable copy of {@code claims} restricted to {@code sub} plus the claims tied to
   * each requested standard scope ({@code address}, {@code email}, {@code phone}, {@code profile}).
   *
   * <p>Claim values are never logged here: they contain personal data of the end user.
   *
   * @param claims all claims available for the user (here: the ID-token claims)
   * @param requestedScopes the scopes granted to the access token
   * @return a mutable map holding only the claims the granted scopes allow to be released
   */
  private static Map<String, Object> getClaimsRequestedByScope(
      Map<String, Object> claims, Set<String> requestedScopes) {
    Set<String> scopeRequestedClaimNames = new HashSet<>(32);
    // The subject identifier is always released, independent of the granted scopes.
    scopeRequestedClaimNames.add(StandardClaimNames.SUB);
    if (requestedScopes.contains(OidcScopes.ADDRESS)) {
      scopeRequestedClaimNames.add(StandardClaimNames.ADDRESS);
    }
    if (requestedScopes.contains(OidcScopes.EMAIL)) {
      scopeRequestedClaimNames.addAll(EMAIL_CLAIMS);
    }
    if (requestedScopes.contains(OidcScopes.PHONE)) {
      scopeRequestedClaimNames.addAll(PHONE_CLAIMS);
    }
    if (requestedScopes.contains(OidcScopes.PROFILE)) {
      scopeRequestedClaimNames.addAll(PROFILE_CLAIMS);
    }

    // Copy first so the caller's map is left untouched, then drop every claim not allowed above.
    Map<String, Object> requestedClaims = new HashMap<>(claims);
    requestedClaims.keySet().retainAll(scopeRequestedClaimNames);

    return requestedClaims;
  }
}
